#!/usr/bin/env python3

from functools import partial
from typing import Callable, Literal, Type

import torch
import torch.nn as nn


from modern_hopfield_attention.layers import SelfAttention, Mlp, PatchEmbed

#!/usr/bin/env python3
from typing import Callable, Literal, Type, Optional
from torch import Tensor
import torch.nn as nn
from timm.models.vision_transformer import (
    get_norm_layer,
    get_act_layer,
    Block,
    Mlp,
    PatchEmbed,
    LayerType,
)
from timm.models import VisionTransformer


class _CustomBlock(Block):
    """timm ``Block`` whose attention module is replaced by ``SelfAttention``.

    The parent constructor builds the norm / MLP / layer-scale / drop-path
    sub-modules as usual. Afterwards ``self.attn`` is swapped for a
    non-causal ``SelfAttention`` configured with the same hyper-parameters.

    Args:
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-size multiplier for the MLP.
        qkv_bias: Whether the qkv projection has a bias.
        qk_norm: Whether queries/keys are normalised.
        proj_drop: Dropout rate after the output projections.
        attn_drop: Dropout rate on the attention weights.
        init_values: Layer-scale init value (``None`` disables it in timm).
        drop_path: Stochastic-depth rate.
        act_layer: Activation class used by the MLP.
        norm_layer: Normalisation class.
        mlp_layer: MLP class.
        **kwargs: Any further keyword arguments that newer timm versions pass
            to ``block_fn`` (e.g. ``proj_bias``). They are forwarded to the
            parent ``Block`` only; the replacement ``SelfAttention`` does not
            receive them.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0,
        attn_drop: float = 0,
        init_values: float | None = None,
        drop_path: float = 0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
        **kwargs,
    ) -> None:
        # Pass everything by keyword: timm has inserted new parameters into
        # ``Block.__init__`` between releases (e.g. ``proj_bias`` after
        # ``qk_norm``), which would silently shift positional arguments.
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_drop=proj_drop,
            attn_drop=attn_drop,
            init_values=init_values,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
            mlp_layer=mlp_layer,
            **kwargs,
        )
        # Replace the attention built by the parent with the project's
        # SelfAttention (bidirectional, hence causal=False).
        self.attn = SelfAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            causal=False,
        )


class CustomVisionTransformer(VisionTransformer):
    """timm ``VisionTransformer`` built from ``_CustomBlock`` layers by default.

    The constructor forwards every argument unchanged to timm; only the
    default ``block_fn`` differs. The class also offers forward hooks that
    record the input of each block's ``SelfAttention`` module, for later
    inspection via ``self.hook_input``.
    """

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "avgmax", "max", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        class_token: bool = True,
        pos_embed: str = "learn",
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: bool | None = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0,
        pos_drop_rate: float = 0,
        patch_drop_rate: float = 0,
        proj_drop_rate: float = 0,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        fix_init: bool = False,
        embed_layer: Callable = PatchEmbed,
        norm_layer: Optional[LayerType] = None,
        act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = _CustomBlock,
        mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            init_values=init_values,
            class_token=class_token,
            pos_embed=pos_embed,
            no_embed_class=no_embed_class,
            reg_tokens=reg_tokens,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            dynamic_img_size=dynamic_img_size,
            dynamic_img_pad=dynamic_img_pad,
            drop_rate=drop_rate,
            pos_drop_rate=pos_drop_rate,
            patch_drop_rate=patch_drop_rate,
            proj_drop_rate=proj_drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            fix_init=fix_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
            mlp_layer=mlp_layer,
        )

    def register_hooks(self) -> None:
        """Attach forward hooks that record each block's attention input.

        During every forward pass, ``self.hook_input`` receives one detached
        CPU copy of the first positional input of each block's ``attn``
        module, provided that module is a ``SelfAttention``. The capture list
        is reset here. Calling this method again replaces previously
        registered hooks instead of stacking duplicates.
        """
        self.remove_hooks()
        self.hook_input = list()

        def hook_fn(module, input, output) -> None:
            if isinstance(module, SelfAttention):
                self.hook_input.append(input[0].detach().cpu())

        # Keep the handles so the hooks can be detached later.
        self._hook_handles = [
            block.attn.register_forward_hook(hook_fn) for block in self.blocks
        ]

    def remove_hooks(self) -> None:
        """Detach all hooks added by ``register_hooks`` (no-op if none)."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def clear_hooks(self) -> None:
        """Discard recorded attention inputs; registered hooks stay active."""
        self.hook_input = list()


class UniversalViT(VisionTransformer):
    """Vision Transformer whose ``depth`` blocks all share one set of weights.

    The timm parent builds the embedding, norm and head modules. A single
    ``block_fn`` layer (``self.layer``) is then created and placed ``depth``
    times in ``self.blocks``, so every iteration reuses the same parameters.

    Notes:
        * ``drop_path_rate`` is used as a constant rate for the shared layer;
          there is no per-depth schedule because there is only one layer.
        * ``fix_init`` is only forwarded to the parent, whose blocks are
          replaced here, so it does not rescale the shared layer.
    """

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        num_classes: int = 1000,
        global_pool: Literal["", "avg", "avgmax", "max", "token", "map"] = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        class_token: bool = True,
        pos_embed: str = "learn",
        no_embed_class: bool = False,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        fc_norm: bool | None = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0,
        pos_drop_rate: float = 0,
        patch_drop_rate: float = 0,
        proj_drop_rate: float = 0,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0,
        weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
        fix_init: bool = False,
        embed_layer: Callable = PatchEmbed,
        norm_layer: Optional[LayerType] = None,
        act_layer: Optional[LayerType] = None,
        block_fn: Type[nn.Module] = _CustomBlock,
        mlp_layer: Type[nn.Module] = Mlp,
    ):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            init_values=init_values,
            class_token=class_token,
            pos_embed=pos_embed,
            no_embed_class=no_embed_class,
            reg_tokens=reg_tokens,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            dynamic_img_size=dynamic_img_size,
            dynamic_img_pad=dynamic_img_pad,
            drop_rate=drop_rate,
            pos_drop_rate=pos_drop_rate,
            patch_drop_rate=patch_drop_rate,
            proj_drop_rate=proj_drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            fix_init=fix_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
            mlp_layer=mlp_layer,
        )

        # Resolve the layer factories the same way timm's VisionTransformer does.
        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        act_layer = get_act_layer(act_layer) or nn.GELU

        # A single layer reused at every depth step (weight sharing).
        self.layer = block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            # Forwarded so layer-scale behaves as in the parent's blocks.
            init_values=init_values,
            proj_drop=proj_drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=drop_path_rate,
            norm_layer=norm_layer,
            act_layer=act_layer,
            mlp_layer=mlp_layer,
        )
        self.blocks = nn.ModuleList([self.layer for _ in range(depth)])

        # The parent applied ``weight_init`` to the blocks it built, which were
        # just replaced. Re-run it so the shared layer follows the same scheme
        # rather than keeping its modules' default initialisation.
        if weight_init != "skip":
            self.init_weights(weight_init)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and run it through the shared layer ``depth`` times."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        for module in self.blocks:
            x = module(x=x)

        x = self.norm(x)

        return x

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return the classification-head output for the image batch ``x``."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    def register_hooks(self) -> None:
        """Attach a forward hook recording the input of the shared attention.

        Every entry of ``self.blocks`` is the same module, so the hook is hit
        once per depth iteration. ``hook_count`` restricts recording to the
        first hit, which is the input to the first iteration, until
        ``clear_hooks`` resets it. Calling this method again replaces the
        previously registered hook instead of stacking duplicates.
        """
        self.remove_hooks()
        self.hook_input = list()
        self.hook_count = 0

        def hook_fn(module, input, output) -> None:
            if isinstance(module, SelfAttention) and self.hook_count == 0:
                self.hook_input.append(input[0].detach().cpu())
                self.hook_count += 1

        # Keep the handle so the hook can be detached later.
        self._hook_handles = [self.layer.attn.register_forward_hook(hook_fn)]

    def remove_hooks(self) -> None:
        """Detach the hook added by ``register_hooks`` (no-op if none)."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def clear_hooks(self) -> None:
        """Discard recorded inputs and re-arm capture for the next pass."""
        self.hook_input = list()
        self.hook_count = 0


if __name__ == "__main__":
    # Smoke test: build a 100-class shared-weight ViT and print a layer
    # summary for a single 224x224 RGB image.
    demo_model = UniversalViT(num_classes=100)
    demo_input = torch.randn(1, 3, 224, 224)

    # Imported lazily: torchinfo is only needed for this demo.
    from torchinfo import summary

    summary(demo_model, input_data=demo_input, depth=4)

    # To inspect the shared attention's input instead, call
    # demo_model.register_hooks() before demo_model(demo_input) and then
    # read demo_model.hook_input.
